import asyncio
import time

from typing import List, Tuple, Any
from collections import OrderedDict

from wintersweet.asyncs.utils.futures import RunAbleFuture
from wintersweet.framework.fastapi.response import HTTP429Response
from wintersweet.utils.base import Utils
from wintersweet.framework.fastapi.limiter.util import get_remote_addr
from wintersweet.asyncs.pool.redis import redis_pool_manager


class BaseLimit:
    """
    Base class for a single request limit.

    Subclasses implement :meth:`is_forbid`. When a limit rejects a request,
    ``LimitGroup.check_request_limited`` builds the rejection by calling
    :attr:`forbid_response_factory` with the detail returned by ``is_forbid``.
    """

    def __init__(self, forbid_response_factory=None):
        # Any falsy factory falls back to the framework's HTTP 429 response.
        self._forbid_response_factory = (
            forbid_response_factory if forbid_response_factory else HTTP429Response
        )

    @property
    def forbid_response_factory(self):
        """Callable that turns a rejection detail into the response to send back."""
        return self._forbid_response_factory

    async def is_forbid(self, request) -> Tuple[bool, Any]:
        """
        Decide whether ``request`` must be rejected.

        :return: ``(forbidden, detail)``; ``detail`` is handed to the response
            factory when ``forbidden`` is true.
        :raises NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError()


class LimitGroup:
    """Evaluate several limits in order, stopping at the first one that rejects."""

    def __init__(self, limits: List[BaseLimit]):
        # Evaluated in list order by ``check_request_limited``.
        self._limits = limits

    @property
    def limits(self):
        """The limits this group evaluates, in evaluation order."""
        return self._limits

    async def check_request_limited(self, request):
        """
        Check ``request`` against every limit.

        :return: the response built by the first limit whose ``is_forbid``
            reports a rejection, or ``False`` when no limit rejects the request.
        """
        for candidate in self._limits:
            forbidden, detail = await candidate.is_forbid(request)
            if not forbidden:
                continue
            return candidate.forbid_response_factory(detail)
        return False


class KeyRateLimit(BaseLimit):
    """
    Per-key request-rate limiter.

    Requests are counted per ``tag`` in fixed windows of ``seconds`` seconds. By
    default the tag is the client address returned by ``get_remote_addr``; pass
    ``key`` to derive it differently. More than ``times`` requests in one window
    are rejected. Counters are kept in process memory unless ``redis`` is given,
    in which case they live in Redis and can be shared across processes. An
    optional ``ntp`` object supplies the current time through its ``timestamp``
    attribute instead of ``time.time()``.

    :param times: maximum number of requests allowed per window.
    :param seconds: window length in seconds.
    :param forbid_seconds: once the limit is exceeded, keep rejecting the tag for
        this many seconds; with ``0`` the tag is only rejected until the current
        window ends.
    :param key: callable ``request -> tag``.
    :param name: limiter name, used to namespace the Redis keys.
    :param max_tags_size: once this many tags are tracked locally, stale entries
        are evicted (default ``0xffff``).
    :param redis: value handed to ``redis_pool_manager.get_client``.
    :param ntp: optional clock object exposing ``timestamp``.
    :param forbid_response_factory: builds the rejection response from the
        "Next visit time" message; defaults to ``HTTP429Response``.
    """
    def __init__(self,
                 times=5,
                 seconds=60,
                 forbid_seconds=0,
                 key=None,
                 name=None,
                 max_tags_size=None,
                 redis=None,
                 ntp=None,
                 forbid_response_factory=None,
                 ):
        super().__init__(forbid_response_factory)
        self._name = name
        self._times = times
        self._seconds = seconds
        self._forbid_seconds = forbid_seconds
        self._key = key or get_remote_addr
        self._max_tags_size = max_tags_size or 0xffff
        self._redis = redis
        self._ntp = ntp
        # Local mode only: tag -> {window_index: request_count}
        self._counter = {}
        # Local mode only: tag -> timestamp until which the tag is banned
        self._forbidden = {}
        if ntp:
            assert hasattr(ntp, 'timestamp'), '"ntp" client must has attr "timestamp"'

    @property
    def times(self):
        return self._times

    @property
    def name(self):
        return self._name

    @property
    def seconds(self):
        return self._seconds

    @property
    def forbid_seconds(self):
        return self._forbid_seconds

    async def is_forbid(self, request) -> Tuple[bool, Any]:
        """
        Count this request and report whether it must be rejected.

        :return: ``(True, 'Next visit time: ...')`` when rejected, otherwise
            ``(False, None)``.
        """
        timestamp = time.time() if not self._ntp else self._ntp.timestamp
        tag = self._key(request)
        if self._redis is None:
            return self._local_forbid(timestamp, tag)
        return await self._distributed_forbid(timestamp, tag)

    def _redis_namespace(self):
        # Unnamed limiters fall back to ``times`` so their Redis keys stay the
        # same as in earlier releases, which built keys from that value.
        return self._name if self._name is not None else self._times

    async def _distributed_forbid(self, timestamp, tag):
        """Count one request for ``tag`` using counters and bans stored in Redis."""
        namespace = self._redis_namespace()
        forbid_key = f'FORBID_{namespace}_{tag}'
        per = timestamp // self._seconds
        incr_key = f'INCR_{namespace}_{tag}_{per}'

        async with redis_pool_manager.get_client(self._redis) as cache:
            forbid_time = await cache.get(forbid_key)
            if forbid_time:
                if float(forbid_time) > timestamp:
                    return True, f'Next visit time: {Utils.stamp2time(int(float(forbid_time)))}'
                # The ban has lapsed but its key is still present (e.g. its
                # EXPIRE never ran); clear it and count the request normally.
                await cache.delete(forbid_key)

            count = await cache.incrby(incr_key, 1)
            if count == 1:
                # First hit of this window: let the counter vanish with the window.
                await cache.expire(incr_key, self._seconds)

            if count <= self._times:
                return False, None

            if not self._forbid_seconds:
                # No extra ban (mirrors the local path): the tag may come back
                # when the current window ends. Writing a ban key here and then
                # EXPIRE-ing it with 0 would just delete it again.
                next_visit_time = self._seconds * (per + 1)
            else:
                next_visit_time = timestamp + self._forbid_seconds
                if await cache.setnx(forbid_key, next_visit_time):
                    await cache.expire(forbid_key, self._forbid_seconds)
                else:
                    # Another request set the ban first; report its deadline.
                    # The key may have expired in between, so keep our own value
                    # rather than calling float(None).
                    existing = await cache.get(forbid_key)
                    if existing:
                        next_visit_time = float(existing)

            return True, f'Next visit time: {Utils.stamp2time(int(next_visit_time))}'

    def _local_forbid(self, timestamp, key):
        """Count one request for ``key`` using the in-process dictionaries."""
        forbid_time = self._forbidden.get(key)
        if forbid_time is not None:
            if forbid_time < timestamp:
                # Ban expired: forget it and count the request normally.
                del self._forbidden[key]
            else:
                return True, f'Next visit time: {Utils.stamp2time(int(forbid_time))}'

        per = timestamp // self._seconds

        window = self._counter.get(key)
        if window is None or per not in window:
            # First request of this tag in the current window: start a fresh count.
            self._counter[key] = {per: 1}
        else:
            window[per] += 1
        count = self._counter[key][per]

        # Tag overflow: drop tags with no count in the current or next window.
        # The current key always has ``per`` and therefore survives.
        if len(self._counter) >= self._max_tags_size:
            next_per = per + 1
            for _tag in list(self._counter):
                pers = self._counter[_tag]
                if per not in pers and next_per not in pers:
                    del self._counter[_tag]

        if count <= self._times:
            return False, None

        if self._forbid_seconds:
            next_visit_time = timestamp + self._forbid_seconds
            self._forbidden[key] = next_visit_time

            # Tag overflow: drop bans that have already lapsed.
            if len(self._forbidden) >= self._max_tags_size:
                for _tag in list(self._forbidden):
                    if self._forbidden[_tag] <= timestamp:
                        del self._forbidden[_tag]
        else:
            # No extra ban: the tag may come back when the current window ends.
            next_visit_time = self._seconds * (per + 1)

        return True, f'Next visit time: {Utils.stamp2time(int(next_visit_time))}'


class APIRateLimit:
    """
    Concurrency limiter.

    At most ``running_size`` tasks run at once, and at most ``waiting_size`` more
    are queued. ``append`` returns the created task, or ``None`` when both are
    full; presumably the caller then answers with ``forbid_response_factory``
    (confirm at the call site). Queued tasks are started in FIFO order as
    running tasks finish.

    Objects produced by ``task_factory`` must expose ``tag``, a ``run()``
    coroutine, ``add_done_callback()`` and ``done()``. This is the asyncio.Future
    API, which ``RunAbleFuture`` is assumed to provide; verify.
    """

    def __init__(self, running_size=10, waiting_size=10, forbid_response_factory=None, task_factory=None):

        self._running_size = running_size
        self._waiting_size = waiting_size
        # tag -> task, kept in submission order
        self._running_tasks = OrderedDict()
        self._waiting_tasks = OrderedDict()
        self._forbid_response_factory = forbid_response_factory or HTTP429Response
        self._task_factory = task_factory or RunAbleFuture
        # Strong references to the asyncio tasks driving ``task.run()``. The event
        # loop keeps only weak references, so without this a runner could be
        # garbage-collected before it finishes.
        self._runners = set()

    @property
    def forbid_response_factory(self):
        return self._forbid_response_factory

    @property
    def running_size(self):
        return self._running_size

    @property
    def waiting_size(self):
        return self._waiting_size

    def _create_task(self, func, args, kwargs):
        """Wrap the call in a new task tagged with a fresh uuid1 hex string."""
        tag = Utils.uuid.uuid1().hex

        task = self._task_factory(tag, func, func_args=args, func_kwargs=kwargs)

        # Frees the task's slot (and promotes a queued task) once it completes.
        task.add_done_callback(self._remove)

        return tag, task

    def _start(self, task):
        """Schedule ``task.run()`` on the running loop and keep a reference to it."""
        runner = asyncio.create_task(task.run())
        self._runners.add(runner)
        runner.add_done_callback(self._runners.discard)

    def append(self, func, args, kwargs):
        """
        Submit a call for execution.

        :return: the task when it was started or queued; ``None`` when both the
            running set and the waiting queue are full.
        """
        if len(self._running_tasks) < self._running_size:
            tag, task = self._create_task(func, args, kwargs)
            self._running_tasks[tag] = task
            self._start(task)
            return task

        if len(self._waiting_tasks) < self._waiting_size:
            tag, task = self._create_task(func, args, kwargs)
            self._waiting_tasks[tag] = task
            return task

        return None

    def _remove(self, f: 'RunAbleFuture'):
        """Done-callback: release a finished task's slot and promote a queued task."""
        tag = f.tag
        if self._running_tasks.pop(tag, None) is None:
            # The task finished (e.g. was cancelled) while still queued, or was
            # already dropped: remove it from the queue without freeing a running
            # slot. Previously this raised KeyError and left the task queued.
            self._waiting_tasks.pop(tag, None)
            return

        self._start_next_waiting()

    def _start_next_waiting(self):
        """Start the oldest queued task that has not already completed."""
        while self._waiting_tasks:
            tag, task = self._waiting_tasks.popitem(last=False)
            if task.done():
                # Completed before it got a slot; running it would act on a done
                # future. Its own done-callback will find it in neither dict.
                continue
            self._running_tasks[tag] = task
            self._start(task)
            break
